import tensorflow as tf
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from empiricalVolterra import empirical_Volterra
from SGDMomentum import run_momentum_batch
from sklearn.preprocessing import StandardScaler


'''
----------------------------------------
Code to generate figure 3 in the paper.
----------------------------------------
'''

# Problem size: d features per sample (28x28 pixel images) and n samples.
d = 28*28
n = 60000

# Load MNIST and transform it into an even/odd problem.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Flatten every image to a length-d vector; the number of rows is taken from
# the loaded data instead of being hard-coded.
flat = x_train.reshape(len(x_train), d)
# Standardise each feature (column) with scikit-learn's StandardScaler.
scaler = StandardScaler().fit(flat)
rescaled_flat = scaler.transform(flat)


A = rescaled_flat[0:n]
b = (y_train[0:n] % 2 - 0.5)  # odd/even => +0.5/-0.5

# Rescale the rows of A to unit Euclidean norm. Broadcasting the reciprocal
# norms performs the same per-entry multiplication as left-multiplying by a
# diagonal matrix, without building an n-by-n sparse operator.
row_norms = np.linalg.norm(A, axis=1)
A = A * (1.0 / row_norms)[:, np.newaxis]  # row norms are 1

# Only the singular values are used below, so skip computing U and V^T.
S = np.linalg.svd(A, compute_uv=False)
# Squared singular values, zero-padded to length n. Padding by the actual
# shortfall equals n - d for the current sizes and also stays valid if n is
# lowered below d.
eiglist = S**2
eiglist = np.pad(eiglist, (0, n - eiglist.size))

#----------- MNIST Experiments -----------
# Hyperparameters shared by the SGD+M runs and the Volterra curve.
Delta = 0.8            # forwarded as `Delta` (presumably the momentum parameter -- confirm in SGDMomentum)
learning_rate = 0.001  # step size; also passed to empirical_Volterra as `gamma`
beta = 30000           # batch size used by run_momentum_batch
R = 11000              # forwarded unchanged to empirical_Volterra
R_tilde = 5300         # forwarded unchanged to empirical_Volterra
max_iter = 50

num_trials = 10
mnist_losscurves = []
x_0 = np.zeros(d)  # every trial is started from the origin

# Repeat the run `num_trials` times and keep one loss curve per trial.
for _ in range(num_trials):
    # Hand each trial its own copy of the starting point so that a solver
    # which updates `x` in place cannot carry state over into the next trial.
    # A fresh `loss_history` list is likewise supplied on every call.
    _, sgd_losscurve = run_momentum_batch(
        A=A, b=b, x=x_0.copy(), n=n, d=d, max_iter=max_iter,
        batch_size=beta, learning_rate=learning_rate,
        Delta=Delta, loss_history=[],
    )
    mnist_losscurves.append(sgd_losscurve)

# Stack the per-trial curves into one array (one row per trial).
mnist_losscurves = np.array(mnist_losscurves)

# Tick-label sizes are rc settings, so they are set before the figure is
# created to make sure the axes of this figure pick them up.
plt.rc('xtick', labelsize=30)
plt.rc('ytick', labelsize=30)

cs = plt.get_cmap("viridis")
plt.figure(figsize=(18, 10.0))
plt.yscale("log")
plt.ylabel("Function Values", fontsize='40')
plt.xlabel("Iterations", fontsize='40')

# Batch fraction zeta = beta / n, computed once and reused for the Volterra
# call and for both legend entries.
zeta = np.float64(beta) / np.float64(n)
zeta_label = r"$(\zeta = $" + str(zeta) + ")"

# Volterra curve for the same hyperparameters, drawn in red.
empiricalVolterra_values = empirical_Volterra(
    n=n, max_iter=max_iter, gamma=learning_rate, eiglist=eiglist,
    zeta=zeta, Delta=Delta, R=R, R_tilde=R_tilde,
)
plt.plot(empiricalVolterra_values, c="r", label="Empirical Volterra " + zeta_label)

# 10%-90% quantile band of the SGD+M loss curves across trials.
lq = np.quantile(mnist_losscurves, 0.1, axis=0)
hq = np.quantile(mnist_losscurves, .90, axis=0)

# The x-grid is derived from the curves themselves so its length always
# matches the band being filled.
plt.fill_between(np.arange(lq.shape[0]), lq,
                 hq, facecolor=cs(0.1), label="SGD+M " + zeta_label)

plt.legend(loc="upper right", fontsize='30')

# Save before show(): the figure may be cleared once the window is closed.
plt.savefig("MNISTMidBatch.pdf", transparent=True)
plt.show()
